{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from langchain.document_loaders import PyPDFLoader\n",
    "from langchain.document_loaders import PyPDFDirectoryLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "import uuid\n",
    "import json\n",
    "import ollama.client as client\n",
    "\n",
    "\n",
    "\n",
    "splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size = 800,\n",
    "    chunk_overlap  = 100,\n",
    "    length_function = len,\n",
    "    is_separator_regex = False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-large-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of parameters -> 332.538889 Mn\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "## Roberta based NER\n",
    "# ner = pipeline(\"token-classification\", model=\"2rtl3/mn-xlm-roberta-base-named-entity\", aggregation_strategy=\"simple\")\n",
    "ner = pipeline(\"token-classification\", model=\"dslim/bert-large-NER\", aggregation_strategy=\"simple\")\n",
    "\n",
    "\n",
    "print(\"Number of parameters ->\", ner.model.num_parameters()/1000000, \"Mn\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "def row2NamedEntities(row):\n",
    "    # print(row)\n",
    "    ner_results = ner(row['text'])\n",
    "    metadata = {'chunk_id': row['chunk_id']}\n",
    "    entities = []\n",
    "    for result in ner_results:\n",
    "        entities = entities + [{'name': result['word'], 'entity': result['entity_group'], **metadata}]\n",
    "        \n",
    "    return entities\n",
    "\n",
    "def dfText2DfNE(dataframe):\n",
    "    ## Takes a dataframe from the parsed data and returns dataframe with named entities. \n",
    "    ## The input dataframe must have a text and a chunk_id column. \n",
    "\n",
    "    ## Using swifter for parallelism\n",
    "    ## 1. Calculate named entities for each row of the dataframe. \n",
    "    results = dataframe.apply(row2NamedEntities, axis=1)\n",
    "\n",
    "    ## Flatten the list of lists to one single list of entities. \n",
    "    entities_list = np.concatenate(results).ravel().tolist()\n",
    "\n",
    "    ## Remove all NaN entities\n",
    "    entities_dataframe = pd.DataFrame(entities_list).replace(' ', np.nan)\n",
    "    entities_dataframe = entities_dataframe.dropna(subset=['entity'])\n",
    "\n",
    "    ## Count the number of occurances per chunk id\n",
    "    entities_dataframe = entities_dataframe.groupby(['name', 'entity', 'chunk_id']).size().reset_index(name='count')\n",
    "\n",
    "    return entities_dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "167"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loader = PyPDFLoader(\"./data/GlobalPublicHealth2022.pdf\")\n",
    "# loader = PyPDFDirectoryLoader(\"./data/kesy1dd\")\n",
    "\n",
    "pages = loader.load_and_split(text_splitter=splitter)\n",
    "len(pages)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "rows = []\n",
    "for page in pages:\n",
    "    row = {'text': page.page_content, **page.metadata, 'chunk_id': uuid.uuid4().hex}\n",
    "    rows += [row]\n",
    "\n",
    "df = pd.DataFrame(rows)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfne = dfText2DfNE(df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>name</th>\n",
       "      <th>entity</th>\n",
       "      <th>count</th>\n",
       "      <th>chunk_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>486</td>\n",
       "      <td>WHO</td>\n",
       "      <td>ORG</td>\n",
       "      <td>228</td>\n",
       "      <td>02db3a55557341d8ba3851dedd6223ed,077b1c4c99064...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>390</td>\n",
       "      <td>Region</td>\n",
       "      <td>MISC</td>\n",
       "      <td>72</td>\n",
       "      <td>07a0744df31b4bfeaa18460091a48c6f,082f7b2fd6794...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>384</td>\n",
       "      <td>R</td>\n",
       "      <td>MISC</td>\n",
       "      <td>57</td>\n",
       "      <td>088ecaeee2804c10bb2a8eb885930949,1a1455b89c444...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>160</td>\n",
       "      <td>E</td>\n",
       "      <td>MISC</td>\n",
       "      <td>37</td>\n",
       "      <td>045bd2ce17da4b839e1198a6603c3b34,0a342463219b4...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>161</td>\n",
       "      <td>E</td>\n",
       "      <td>ORG</td>\n",
       "      <td>36</td>\n",
       "      <td>02db3a55557341d8ba3851dedd6223ed,0848d9f456f04...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>195</td>\n",
       "      <td>Event Management System</td>\n",
       "      <td>MISC</td>\n",
       "      <td>3</td>\n",
       "      <td>15113bcf8e2f4018a589268bf2e67852,650b458e34044...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>124</td>\n",
       "      <td>China</td>\n",
       "      <td>LOC</td>\n",
       "      <td>3</td>\n",
       "      <td>7a56d66c7a9a4b249c4b5e7a9ad1bcd0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>131</td>\n",
       "      <td>Congo</td>\n",
       "      <td>LOC</td>\n",
       "      <td>3</td>\n",
       "      <td>045bd2ce17da4b839e1198a6603c3b34,d510aaa685194...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>440</td>\n",
       "      <td>South -</td>\n",
       "      <td>MISC</td>\n",
       "      <td>3</td>\n",
       "      <td>66faf6d1a4334bf69bd0300b7653096c,fb187a00ea554...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>121</td>\n",
       "      <td>Chiku</td>\n",
       "      <td>MISC</td>\n",
       "      <td>3</td>\n",
       "      <td>5b9a455f599e4608a50da16e71a71a31,616c2f5978124...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    index                     name entity  count  \\\n",
       "0     486                      WHO    ORG    228   \n",
       "1     390                   Region   MISC     72   \n",
       "2     384                        R   MISC     57   \n",
       "3     160                        E   MISC     37   \n",
       "4     161                        E    ORG     36   \n",
       "..    ...                      ...    ...    ...   \n",
       "95    195  Event Management System   MISC      3   \n",
       "96    124                    China    LOC      3   \n",
       "97    131                    Congo    LOC      3   \n",
       "98    440                  South -   MISC      3   \n",
       "99    121                    Chiku   MISC      3   \n",
       "\n",
       "                                             chunk_id  \n",
       "0   02db3a55557341d8ba3851dedd6223ed,077b1c4c99064...  \n",
       "1   07a0744df31b4bfeaa18460091a48c6f,082f7b2fd6794...  \n",
       "2   088ecaeee2804c10bb2a8eb885930949,1a1455b89c444...  \n",
       "3   045bd2ce17da4b839e1198a6603c3b34,0a342463219b4...  \n",
       "4   02db3a55557341d8ba3851dedd6223ed,0848d9f456f04...  \n",
       "..                                                ...  \n",
       "95  15113bcf8e2f4018a589268bf2e67852,650b458e34044...  \n",
       "96                   7a56d66c7a9a4b249c4b5e7a9ad1bcd0  \n",
       "97  045bd2ce17da4b839e1198a6603c3b34,d510aaa685194...  \n",
       "98  66faf6d1a4334bf69bd0300b7653096c,fb187a00ea554...  \n",
       "99  5b9a455f599e4608a50da16e71a71a31,616c2f5978124...  \n",
       "\n",
       "[100 rows x 5 columns]"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ne = dfne.groupby(['name', 'entity']).agg({'count': 'sum', 'chunk_id': ','.join}).reset_index()\n",
    "df_ne.sort_values(by='count', ascending=False).head(100).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'assesses the likelihood and consequences of an acute public health threat due to exposure \\nto an identified hazard. RRA involves a joint assessment by the WHO Country and Regional \\nOffices and headquarters. It is conducted for events with serious public health implications \\nfollowing pre-defined criteria within WHO. The RRA process provides a forum for the \\ntimely assessment of available data, which takes into account the contextual and hazard-\\nspecific knowledge and feedback of key experts across WHO. It supports a collaborative \\nexpert prioritization of immediate actions in a time-sensitive manner. The finalized RRA \\nreport may be shared with key stakeholders that could contribute to the response. RRA \\nreports have become well accepted documents of high practical value both within WHO'"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pages[12].page_content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def extractConcepts(prompt: str, model='mistral-openorca:latest'):\n",
    "    SYS_PROMPT = (\n",
    "        \"Your task is to extract the key entities mentioned in the users input.\\n\"\n",
    "        \"Entities may include - event, concept, person, place, object, document, organisation, artifact, misc, etc.\\n\"\n",
    "        \"Format your output as a list of json with the following structure.\\n\"\n",
    "        \"[{\\n\"\n",
    "        \"   \\\"entity\\\": The Entity string\\n\"\n",
    "        \"   \\\"importance\\\": How important is the entity given the context on a scale of 1 to 5, 5 being the highest.\\n\"\n",
    "        \"   \\\"type\\\": Type of entity\\n\"\n",
    "        \"}, { }]\"\n",
    "    )\n",
    "    response, context = client.generate(model_name=model, system=SYS_PROMPT, prompt=prompt)\n",
    "    return json.loads(response)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " [\n",
      "   {\n",
      "      \"entity\": \"acute public health events\",\n",
      "      \"importance\": 4,\n",
      "      \"type\": \"event\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"infectious diseases\",\n",
      "      \"importance\": 3,\n",
      "      \"type\": \"concept\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"disasters\",\n",
      "      \"importance\": 2,\n",
      "      \"type\": \"concept\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"EMS\",\n",
      "      \"importance\": 3,\n",
      "      \"type\": \"document\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"RRA reports\",\n",
      "      \"importance\": 2,\n",
      "      \"type\": \"document\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"EIS bulletins\",\n",
      "      \"importance\": 2,\n",
      "      \"type\": \"document\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"DON reports\",\n",
      "      \"importance\": 2,\n",
      "      \"type\": \"document\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"WHO Regions\",\n",
      "      \"importance\": 3,\n",
      "      \"type\": \"organization\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"IHR (2005) framework\",\n",
      "      \"importance\": 4,\n",
      "      \"type\": \"concept\"\n",
      "   },\n",
      "   {\n",
      "      \"entity\": \"WHO\",\n",
      "      \"importance\": 4,\n",
      "      \"type\": \"organization\"\n",
      "   }\n",
      "]"
     ]
    }
   ],
   "source": [
    "res = extractConcepts(prompt = pages[22].page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'entity': 'infectious diseases', 'importance': 4},\n",
       " {'entity': 'disasters', 'importance': 3},\n",
       " {'entity': 'EMS', 'importance': 3},\n",
       " {'entity': 'RRA reports', 'importance': 2},\n",
       " {'entity': 'EIS bulletins', 'importance': 1.5},\n",
       " {'entity': 'DON reports', 'importance': 1},\n",
       " {'entity': 'WHO Regions', 'importance': 2},\n",
       " {'entity': 'IHR (2005) framework', 'importance': 3}]"
      ]
     },
     "execution_count": 168,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "OpenAI@3111",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
